import os
import sys
import pybedtools
import pysam


# Expect exactly two command-line arguments (dataset name, library name);
# exit with a usage message rather than an opaque unpacking ValueError.
if len(sys.argv) != 3:
    sys.exit("Usage: %s DATASET LIBRARY" % sys.argv[0])
dataset, library = sys.argv[1:]

# Values of the "XT" (target) tag. Reads whose target is listed in
# skip_targets are discarded; any other target is expected to appear in
# keep_targets and is treated as an error otherwise.
keep_targets = {
    'mRNA', 'lncRNA', 'gencode', 'fantomcat', 'genome',
    'MALAT1', 'TERC', 'RMRP', 'RPPH', 'snhg',
}

skip_targets = {
    'chrM', 'rRNA', 'tRNA', 'snRNA', 'scRNA', 'snoRNA', 'yRNA',
    'histone', 'scaRNA', 'snar', 'vRNA',
}

# Values of the optional "XA" (annotation) tag. Reads annotated with a value
# in skip_annotations are discarded; any other annotation is expected to
# appear in keep_annotations and is treated as an error otherwise.
keep_annotations = {
    'FANTOM5_enhancer',
    'roadmap_enhancer', 'roadmap_dyadic',
    'novel_enhancer_CAGE', 'novel_enhancer_HiSeq',
    'sense_proximal', 'prompt', 'antisense',
    'sense_upstream', 'sense_distal',
    'sense_distal_upstream',
    'antisense_distal', 'antisense_distal_upstream',
}

skip_annotations = {'presnoRNA', 'prescaRNA', 'presnRNA', 'pretRNA'}




def parse_lines(dataset, library):
    """Yield alignment records from ``<library>.bam`` for *dataset*.

    The BAM file is read from
    ``/osc-fs_home/mdehoon/Data/CASPARs/<dataset>/Mapping``. For the
    HiSeq, CAGE and StartSeq datasets every record is yielded. For MiSeq the
    records are consumed in consecutive pairs and only the first record of
    each pair is yielded (presumably read 1 of a read pair -- confirm against
    the mapping pipeline).

    Raises:
        ValueError: if *dataset* is not a known dataset, or if a MiSeq file
            holds an odd number of records.
    """
    # Validate before touching the filesystem.
    if dataset in ("HiSeq", "CAGE", "StartSeq"):
        paired = False
    elif dataset == "MiSeq":
        paired = True
    else:
        raise ValueError("Unknown dataset %s" % dataset)
    directory = "/osc-fs_home/mdehoon/Data/CASPARs/%s/Mapping" % dataset
    filename = "%s.bam" % library
    path = os.path.join(directory, filename)
    print("Reading", path)
    # The context manager closes the file even when the generator is closed
    # early or an exception propagates mid-iteration.
    with pysam.AlignmentFile(path) as alignments:
        if not paired:
            yield from alignments
            return
        for line1 in alignments:
            # Consume and discard the second record of the pair. A bare
            # next() would surface a missing mate as RuntimeError inside a
            # generator (PEP 479); report it explicitly instead.
            if next(alignments, None) is None:
                raise ValueError("Odd number of records in %s" % path)
            yield line1

def analyze_bamfile(dataset, library):
    """Collapse filtered, uniquely mapped reads into 1-bp CTSS intervals.

    Alignment records come from ``parse_lines(dataset, library)``. A record
    is dropped when any of the following holds:

    * it is unmapped;
    * its ``XT`` tag is in ``skip_targets``;
    * its optional ``XA`` tag is in ``skip_annotations``;
    * its ``NH`` tag is not 1.

    Each surviving record is assigned a position. On the reverse strand this
    is ``reference_end - 1``; otherwise it is ``reference_start``. A run of
    consecutive records that share (chromosome, position, strand) is emitted
    as one pybedtools interval with these fields:

        ``[chromosome, position, position + 1, targets, count, strand]``

    *targets* is the sorted, comma-joined set of ``XT`` values in the run.
    *count* is the number of records in the run, formatted as a float string
    (e.g. ``"3.0"``). Equal positions that are not consecutive in the input
    produce separate intervals.

    Yields nothing when no record survives the filters.

    Raises:
        ValueError: if an ``XT`` value is in neither ``keep_targets`` nor
            ``skip_targets``, or an ``XA`` value is in neither
            ``keep_annotations`` nor ``skip_annotations``.
    """
    current = None      # (chromosome, position, strand) of the open run
    targets = set()     # XT values seen in the open run
    count = 0.0         # number of records in the open run
    for line in parse_lines(dataset, library):
        if line.is_unmapped:
            continue
        target = line.get_tag("XT")
        if target in skip_targets:
            continue
        # Explicit check rather than assert: asserts vanish under ``python -O``.
        if target not in keep_targets:
            raise ValueError("Unexpected XT target %r" % (target,))
        try:
            annotation = line.get_tag("XA")
        except KeyError:
            pass  # a missing XA tag is accepted
        else:
            if annotation in skip_annotations:
                continue
            if annotation not in keep_annotations:
                raise ValueError("Unexpected XA annotation %r" % (annotation,))
        if line.get_tag("NH") != 1:
            continue
        if line.is_reverse:
            strand = '-'
            start = line.reference_end - 1
        else:
            strand = '+'
            start = line.reference_start
        position = (line.reference_name, start, strand)
        if position != current:
            if current is not None:
                yield _make_ctss_interval(current, targets, count)
            current = position
            targets = set()
            count = 0.0
        targets.add(target)
        count += 1.0
    # Flush the final run. The guard keeps an input with no surviving records
    # from trying to build an interval out of current=None.
    if current is not None:
        yield _make_ctss_interval(current, targets, count)


def _make_ctss_interval(position, targets, count):
    """Build a 1-bp, 6-field interval for one run of reads at *position*."""
    chromosome, start, strand = position
    fields = [chromosome,                 # chromosome
              start,                      # start
              start + 1,                  # end
              ",".join(sorted(targets)),  # name
              str(count),                 # score
              strand]                     # strand
    return pybedtools.create_interval_from_list(fields)


# Wrap the collapsed CTSS intervals in a BedTool and render them to a
# temporary file (saveas() with no name). This lets sort() and merge() work
# from a file rather than from the one-shot generator.
ctss = pybedtools.BedTool(analyze_bamfile(dataset, library)).saveas()
print("Sorting")
ctss = ctss.sort()
print("Merging")
# Strand-aware merge (s=True). d=-1 presumably requires real overlap, so
# identical 1-bp positions are combined but adjacent ones are not; confirm
# against the bedtools merge documentation. Output columns 4/5/6 are the
# distinct names, summed scores and distinct strand.
alignments = ctss.merge(s=True, d=-1, c=(4, 5, 6),
                        o=('distinct', 'sum', 'distinct'))

filename = "%s.ctss.bed" % library
print("Writing", filename)
alignments.saveas(filename)
